import torch.nn.functional as F
import torch
import torchvision
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
import gc
import psutil
import os
import tracemalloc
import torch.optim as optim
from gzymodel import ResNet18
import linecache
def log_memory_usage():
    """Print the current process RSS and, when CUDA is usable, GPU allocator stats.

    Guarding the CUDA queries on ``torch.cuda.is_available()`` keeps this helper
    safe on CPU-only hosts, where the CUDA allocator counters are meaningless.
    """
    process = psutil.Process(os.getpid())
    # Resident set size of this Python process, in MiB.
    print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:.2f} MB")
    if torch.cuda.is_available():
        # Allocated = tensors currently live; reserved = blocks cached by the allocator.
        print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
        print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")

def display_top(snapshot, key_type='lineno', limit=10):
    """Print the top *limit* allocation sites from a tracemalloc snapshot.

    Importlib bootstrap frames and unknown frames are filtered out so the
    report focuses on user code. For each top site the source line is shown
    when it can be resolved, followed by a rollup of the remaining sites and
    the grand total.
    """
    filtered = snapshot.filter_traces((
        tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
        tracemalloc.Filter(False, "<unknown>"),
    ))
    stats = filtered.statistics(key_type)

    print("Top %s lines" % limit)
    for rank, stat in enumerate(stats[:limit], 1):
        frame = stat.traceback[0]
        print("#%s: %s:%s: %.1f KiB"
              % (rank, frame.filename, frame.lineno, stat.size / 1024))
        source_line = linecache.getline(frame.filename, frame.lineno).strip()
        if source_line:
            print('    %s' % source_line)

    remainder = stats[limit:]
    if remainder:
        remainder_size = sum(stat.size for stat in remainder)
        print("%s other: %.1f KiB" % (len(remainder), remainder_size / 1024))
    grand_total = sum(stat.size for stat in stats)
    print("Total allocated size: %.1f KiB" % (grand_total / 1024))

def load_dataset():
    """Download CIFAR-10 and return (train, test) DataLoaders.

    Training data gets random-crop + horizontal-flip augmentation; both splits
    are normalized with the standard CIFAR-10 per-channel mean/std.
    """
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Augment only the training split; test images are just normalized.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    train_data = torchvision.datasets.CIFAR10(
        root="../data", train=True, download=True, transform=train_transform)
    test_data = torchvision.datasets.CIFAR10(
        root="../data", train=False, download=True, transform=test_transform)

    # Shuffle only the training loader; evaluation order stays fixed.
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)
    return train_loader, test_loader

def save_and_reset(model, optimizer, file_path, device, iteration, qat=False, nat=False):
    """Checkpoint the model to disk, tear it down, and rebuild it from the checkpoint.

    Used mid-epoch to release accumulated GPU memory: the weights survive via
    the temp file while the live model object (and its CUDA allocations) is
    destroyed and re-created.

    Args:
        model: the network to checkpoint and rebuild.
        optimizer: current optimizer (discarded; a fresh Adam is returned).
        file_path: base path; the temp checkpoint is "<file_path>_tempsave.pth".
        device: device to move the rebuilt model to.
        iteration: current iteration index (informational only; not used here).
        qat, nat: flags forwarded to the ResNet18 constructor.

    Returns:
        (model, optimizer): the reloaded model and a fresh Adam(lr=0.001).

    NOTE(review): the Adam moment estimates are discarded on every reset —
    confirm this loss of optimizer state is intentional.
    """
    iteration_file_path = f"{file_path}_tempsave.pth"
    torch.save(model.state_dict(), iteration_file_path)
    print(f"Model saved to {iteration_file_path}")
    del model
    gc.collect()              # drop Python references first so tensors become collectable...
    torch.cuda.empty_cache()  # ...then return the freed cached blocks to the driver
    model = ResNet18(qat=qat, nat=nat).to(device)
    # map_location keeps the restore correct even if the checkpoint tensors
    # were saved on a different device than the one we rebuild on.
    model.load_state_dict(torch.load(iteration_file_path, map_location=device))
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    print("Model reloaded and environment reset")
    return model, optimizer

def train_runner(model, device, trainloader, optimizer, file_path, qat, nat, save_interval=1000):
    """Train `model` for one pass over `trainloader`, checkpointing periodically.

    Every `save_interval` iterations the running loss/accuracy window is
    reported, the model is saved and rebuilt via `save_and_reset` (to release
    GPU memory), and the window counters are reset. The final state dict is
    written to `file_path`.

    Returns:
        (avg_loss, accuracy) computed over the most recent window of batches.

    Bug fixed vs. the original: if the last batch landed exactly on a save
    point (or the loader was empty), the counters had just been reset to zero
    and the final `running_loss / total` raised ZeroDivisionError. We now only
    recompute the averages when the window is non-empty and reuse the last
    reported values otherwise.
    """
    tracemalloc.start()
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    # Fallbacks in case the final window is empty (save point on last batch,
    # or an empty loader).
    avg_loss = 0.0
    accuracy = 0.0

    for iteration, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        # Weight by batch size so the average is per-sample, not per-batch.
        running_loss += loss.item() * labels.size(0)

        predict = outputs.argmax(dim=-1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()

        loss.backward()
        optimizer.step()

        if iteration != 0 and iteration % save_interval == 0:
            avg_loss = running_loss / total
            accuracy = correct / total
            print(f"Saving model at iteration {iteration}: Loss: {avg_loss}, Accuracy: {100 * accuracy:.2f}%")
            model, optimizer = save_and_reset(model, optimizer, file_path, device, iteration, qat, nat)
            # Start a fresh statistics window after each checkpoint.
            running_loss = 0.0
            total = 0
            correct = 0

    if total > 0:
        avg_loss = running_loss / total
        accuracy = correct / total

    torch.save(model.state_dict(), file_path)
    print(f"Final Loss: {avg_loss}, Final Accuracy: {100 * accuracy:.2f}%")
    log_memory_usage()
    tracemalloc.stop()
    return avg_loss, accuracy


def test_runner(model, device, testloader, use_quantization=False, scale=1, zero_point=0):
    model.eval()
    correct = 0.0
    test_loss = 0.0
    total = 0.0

    with torch.no_grad():
        for data, label in testloader:
            data, label = data.to(device), label.to(device)
            if(use_quantization==True):
                data = quantize(data, scale, zero_point)
                output = model.quantize_inference(data)
                output = dequantize(output, scale, zero_point)
            else:
                output = model(data)
            # print(output)
            # print(label)
            loss = F.cross_entropy(output, label)
            test_loss += loss.item() * label.size(0)  # 累加损失，乘以批次大小
            predict = output.argmax(dim=1)
            total += label.size(0)
            correct += (predict == label).sum().item()
        # print("test_average_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

    avg_test_loss = test_loss / total  # 计算每个样本的平均损失
    accuracy = correct / total  # 计算测试集的准确率
    return avg_test_loss, accuracy


def fig_plot(Loss, test_Loss, Accuracy, test_Accuracy):
    """Plot train/test loss and accuracy curves and save them to metrics.png.

    Produces a two-row figure: loss curves on top, accuracy curves below.
    The figure is saved to disk rather than shown interactively.
    """
    plt.figure(figsize=(10, 6))

    # Top panel: loss curves.
    plt.subplot(2, 1, 1)
    plt.plot(Loss, label='Train Loss', linewidth=2)
    plt.plot(test_Loss, label='Test Loss', linewidth=2)
    plt.title('Loss', fontsize=14)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.legend()

    # Bottom panel: accuracy curves.
    plt.subplot(2, 1, 2)
    plt.plot(Accuracy, label='Train Accuracy', linewidth=2)
    plt.plot(test_Accuracy, label='Test Accuracy', linewidth=2)
    plt.title('Accuracy', fontsize=14)
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Accuracy', fontsize=12)
    plt.legend()

    # Prevent the two panels' labels from overlapping, then persist to disk.
    plt.tight_layout()
    plt.savefig('metrics.png')

def compute_full_dataset_quantization_params(data_loader, bit_width=8):
    """Scan an entire data loader and derive affine quantization parameters.

    Finds the global min/max over every input batch, then computes:
      scale      = (max - min) / (2**bit_width - 1)
      zero_point = round((max + min) / 2)   (as a tensor)

    Returns:
        (scale, zero_point): a float and a rounded scalar tensor.
    """
    global_min = float('inf')
    global_max = float('-inf')

    # One full pass over the dataset to establish the dynamic range.
    for batch_inputs, _ in data_loader:
        global_min = min(global_min, batch_inputs.min().item())
        global_max = max(global_max, batch_inputs.max().item())
    print('all_max, all_min: ', global_max, global_min)

    level_count = 2 ** bit_width - 1
    scale = (global_max - global_min) / level_count
    zero_point = torch.round(torch.tensor((global_max + global_min) / 2))
    print('scale, zero_point: ', scale, zero_point)
    return scale, zero_point

def quantize(x, scale, zero_point, min_val=-128, max_val=127):
    """Affine-quantize a float tensor onto the integer grid.

    Computes round(x / scale + zero_point), then clips the result into
    [min_val, max_val] (signed 8-bit range by default).
    """
    scaled = x / scale + zero_point
    return torch.clamp(torch.round(scaled), min_val, max_val)

def dequantize(x, scale, zero_point):
    """Invert `quantize`: map integer-grid values back to floats via (x - zero_point) * scale."""
    return (x - zero_point) * scale